In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the imported project code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

import os

# The `src.*` imports below and the relative paths used later ("./conf",
# "models/...") are resolved from the directory two levels up from where the
# notebook starts (presumably the repository root).
# Only move when `src` is not already visible, so that re-running this cell
# does not keep climbing the directory tree.
if not os.path.isdir("src"):
    os.chdir("../..")
from pathlib import Path
import logging
import pandas as pd
from prophet import Prophet
from prophet.diagnostics import cross_validation, performance_metrics
from prophet.plot import plot_plotly, plot_components_plotly
from src.config.config import *
from src.features.build_features import run_pipeline
from src.models.train_model import mass_forecaster
import pickle

# Keep the notebook output readable: only errors from Prophet get through,
# and the cmdstanpy logger is switched off entirely.
_prophet_logger = logging.getLogger("prophet")
_prophet_logger.setLevel(logging.ERROR)

_cmdstanpy_logger = logging.getLogger("cmdstanpy")
_cmdstanpy_logger.disabled = True
In [2]:
# Load the project configuration, then build the train/test feature frames
# from the data location it points to.
conf = get_config()
train_test_frames = run_pipeline(conf.DATA_PATH)
df_train, df_test = train_test_frames
Loading config file: "./conf/config.yaml" 
Feature datasets have been saved!
In [3]:
# Preview the training features: rows are keyed by (Store, ds) and carry the
# promo / school-holiday indicator columns plus the target column `y`.
df_train.head()
Out[3]:
cat__Promo_1.0 cat__SchoolHoliday_1.0 y
Store ds
1 2013-01-02 0.0 1.0 5530
2013-01-03 0.0 1.0 4327
2013-01-04 0.0 1.0 4486
2013-01-05 0.0 1.0 4997
2013-01-07 1.0 1.0 7176
In [4]:
# Preview the test features: same (Store, ds) index and indicator columns,
# but without the target column `y`.
df_test.head()
Out[4]:
cat__Promo_1.0 cat__SchoolHoliday_1.0
Store ds
1 2015-08-01 0.0 1.0
2015-08-02 0.0 1.0
2015-08-03 1.0 1.0
2015-08-04 1.0 1.0
2015-08-05 1.0 1.0

Run cross-validation with backfitting and save the best model for each store¶

In [5]:
# Hyper-parameter grid, read from the project config, that the tuning step optimizes over
print(conf.PARAM_GRID)
{'changepoint_prior_scale': [0.001, 0.01, 0.1, 0.5], 'seasonality_mode': ['additive', 'multiplicative'], 'seasonality_prior_scale': [0.01, 0.1, 1.0, 10.0]}
In [6]:
# Cross-validate and backfit: works through the stores one at a time and prints
# the best parameter combination and its RMSE for each store (see the log below).
# NOTE(review): presumably this call also writes the pickled models and the
# result CSVs that later cells read — confirm in src/models/train_model.py.
mass_forecaster(conf)
Starting forecasting procedure for Store:1
Best params are {'changepoint_prior_scale': 0.1, 'seasonality_mode': 'additive', 'seasonality_prior_scale': 0.01}
Best rmse is : 623.84
Starting forecasting procedure for Store:3
Best params are {'changepoint_prior_scale': 0.01, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 0.01}
Best rmse is : 1023.23
Starting forecasting procedure for Store:7
Best params are {'changepoint_prior_scale': 0.01, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 0.01}
Best rmse is : 1595.29
Starting forecasting procedure for Store:8
Best params are {'changepoint_prior_scale': 0.1, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 0.01}
Best rmse is : 813.30
Starting forecasting procedure for Store:9
Best params are {'changepoint_prior_scale': 0.01, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 10.0}
Best rmse is : 1068.19
Starting forecasting procedure for Store:10
Best params are {'changepoint_prior_scale': 0.5, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 0.01}
Best rmse is : 695.85
Starting forecasting procedure for Store:11
Best params are {'changepoint_prior_scale': 0.1, 'seasonality_mode': 'additive', 'seasonality_prior_scale': 0.01}
Best rmse is : 1278.95
Starting forecasting procedure for Store:12
Best params are {'changepoint_prior_scale': 0.5, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 1.0}
Best rmse is : 1285.59
Starting forecasting procedure for Store:13
Best params are {'changepoint_prior_scale': 0.001, 'seasonality_mode': 'additive', 'seasonality_prior_scale': 0.01}
Best rmse is : 1099.03
Starting forecasting procedure for Store:14
Best params are {'changepoint_prior_scale': 0.01, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 0.1}
Best rmse is : 698.78

Load single saved model and predict¶

In [7]:
# Store id whose fitted model should be loaded (e.g. 1, 3, 7, ...).
store = 1

# `pickle` is already imported in the first cell, so it is not re-imported here.
# Unpickling can execute arbitrary code: only load model files that were
# produced by this project's own training run.
model_file = conf.MODEL_PATH / "saved_models" / f"{store}.pkl"
with model_file.open("rb") as handle:
    model_1 = pickle.load(handle)

Visualizations¶

In [8]:
# Stack the selected store's training rows on top of its test rows, let the
# loaded model predict over the whole span, and draw the interactive forecast.
store_history = df_train.loc[store].reset_index()
store_future = df_test.loc[store].reset_index()
store_timeline = pd.concat([store_history, store_future], axis=0)
plot_plotly(model_1, model_1.predict(store_timeline))
In [9]:
# Component breakdown of the forecast for the selected store only.
# `df_test` is indexed by (Store, ds) and holds more than one store, so it must
# be sliced with `.loc[store]` before predicting with that store's model;
# otherwise the model is fed the rows of every store.
plot_components_plotly(model_1, model_1.predict(df_test.loc[store].reset_index()))
In [10]:
# Grid-search scores from the backfitting run: one row per store and
# parameter combination, with the resulting RMSE.
tuning_results = pd.read_csv("models/results/tuning_results.csv", index_col=[0])
tuning_results
Out[10]:
store changepoint_prior_scale seasonality_mode seasonality_prior_scale rmse
0 1 0.001 additive 0.01 758.632796
1 1 0.001 additive 0.10 799.090963
2 1 0.001 additive 1.00 739.941039
3 1 0.001 additive 10.00 786.382349
4 1 0.001 multiplicative 0.01 782.214540
... ... ... ... ... ...
27 14 0.500 additive 10.00 733.883589
28 14 0.500 multiplicative 0.01 773.275699
29 14 0.500 multiplicative 0.10 736.575138
30 14 0.500 multiplicative 1.00 737.317667
31 14 0.500 multiplicative 10.00 737.817153

320 rows × 5 columns

In [11]:
# Saved forecast table: one row per store and date, with the trend and
# seasonal component columns and the point forecast `yhat`.
forecasts = pd.read_csv("models/results/forecasts.csv", index_col=[0])
forecasts
Out[11]:
store ds trend yhat_lower yhat_upper trend_lower trend_upper additive_terms additive_terms_lower additive_terms_upper weekly weekly_lower weekly_upper yearly yearly_lower yearly_upper multiplicative_terms multiplicative_terms_lower multiplicative_terms_upper yhat
0 1 2013-01-02 5284.447944 4437.646265 6615.198120 5284.447944 5284.447944 225.284287 225.284287 225.284287 -168.384685 -168.384685 -168.384685 393.668971 393.668971 393.668971 0.000000 0.000000 0.000000 5509.732231
1 1 2013-01-03 5282.426092 4205.893135 6439.996854 5282.426092 5282.426092 47.561273 47.561273 47.561273 -243.860775 -243.860775 -243.860775 291.422048 291.422048 291.422048 0.000000 0.000000 0.000000 5329.987365
2 1 2013-01-04 5280.404241 4391.347974 6492.928135 5280.404241 5280.404241 168.446565 168.446565 168.446565 -25.282493 -25.282493 -25.282493 193.729058 193.729058 193.729058 0.000000 0.000000 0.000000 5448.850806
3 1 2013-01-05 5278.382389 4433.331538 6531.860868 5278.382389 5278.382389 253.192242 253.192242 253.192242 151.482467 151.482467 151.482467 101.709775 101.709775 101.709775 0.000000 0.000000 0.000000 5531.574631
4 1 2013-01-07 5274.338685 4514.679064 6698.863744 5274.338685 5274.338685 290.639197 290.639197 290.639197 352.115898 352.115898 352.115898 -61.476701 -61.476701 -61.476701 0.000000 0.000000 0.000000 5564.977882
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
43 14 2015-09-13 5691.851317 3896.936149 6524.159405 5691.850807 5691.851818 0.000000 0.000000 0.000000 0.019224 0.019224 0.019224 -0.105064 -0.105064 -0.105064 -0.085840 -0.085840 -0.085840 5203.263804
44 14 2015-09-14 5692.133133 5163.000069 7668.927595 5692.132595 5692.133666 0.000000 0.000000 0.000000 0.234822 0.234822 0.234822 -0.098604 -0.098604 -0.098604 0.136218 0.136218 0.136218 6467.503347
45 14 2015-09-15 5692.414949 4300.793334 6785.913349 5692.414394 5692.415531 0.000000 0.000000 0.000000 0.068748 0.068748 0.068748 -0.091335 -0.091335 -0.091335 -0.022587 -0.022587 -0.022587 5563.837657
46 14 2015-09-16 5692.696765 3926.175141 6464.533322 5692.696186 5692.697366 0.000000 0.000000 0.000000 -0.011180 -0.011180 -0.011180 -0.083386 -0.083386 -0.083386 -0.094565 -0.094565 -0.094565 5154.364468
47 14 2015-09-17 5692.978581 4075.649832 6669.954978 5692.977983 5692.979201 0.000000 0.000000 0.000000 0.004806 0.004806 0.004806 -0.074900 -0.074900 -0.074900 -0.070093 -0.070093 -0.070093 5293.938169

8141 rows × 20 columns